import os
import random
import numpy as np


def get_txtfile(img_path="data", split=(0.6, 0.2, 0.2)):
    """Randomly split the entries of a directory into train/val/test lists.

    The names returned by ``os.listdir(img_path)`` are shuffled with the
    module-level ``random`` generator and written, one name per line, to
    ``train.txt``, ``val.txt`` and ``test.txt`` in the current working
    directory. Existing files with those names are overwritten.

    Args:
        img_path: Directory whose entries are split.
        split: Sequence of fractions ``(train, val, test)``. Only the first
            two items are read; every entry that is left after the train and
            val portions goes to the test list.

    Returns:
        None. The result is the three text files.
    """
    # List the directory exactly once so the entry count and the names can
    # never disagree if the directory changes while this function runs.
    names = os.listdir(img_path)

    # Shuffle the data in place.
    random.shuffle(names)

    total = len(names)
    train_end = int(split[0] * total)
    val_end = int((split[1] + split[0]) * total)

    # Each output file is paired with the slice of names it records.
    subsets = (
        ("train.txt", names[:train_end]),
        ("val.txt", names[train_end:val_end]),
        ("test.txt", names[val_end:]),
    )

    # Write the names for each subset; ``with`` closes the file even when a
    # write fails.
    for txt_name, subset in subsets:
        with open(txt_name, mode="w") as out_file:
            out_file.writelines(str(name) + "\n" for name in subset)


if __name__ == "__main__":
    # Script entry point: build the train/val/test lists for the dataset.
    get_txtfile(
        # Directory in which the dataset is stored.
        img_path="data",
        # Fractions of the whole dataset used for the train, val and test
        # sets, in that order.
        split=[0.3, 0.2, 0.5],
    )
